{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uJHywE_oL3j2"
      },
      "source": [
        "# Differentially private convolutional neural network on MNIST.\n",
        "\n",
        "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.sandbox.google.com/github/google-deepmind/optax/blob/main/examples/differentially_private_sgd.ipynb)\n",
        "\n",
        "A large portion of this code is forked from the differentially private SGD\n",
        "example in the [JAX repo](\n",
        "https://github.com/google/jax/blob/main/examples/differentially_private_sgd.py).\n",
        "\n",
        "[Differentially Private Stochastic Gradient Descent](https://arxiv.org/abs/1607.00133) requires clipping the per-example parameter\n",
        "gradients, which is non-trivial to implement efficiently for convolutional\n",
        "neural networks.  The JAX XLA compiler shines in this setting by optimizing the\n",
        "minibatch-vectorized computation for convolutional architectures. Training\n",
        "takes a few seconds per epoch on a commodity GPU."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VaYIiCnjL3j3"
      },
      "outputs": [],
      "source": [
        "import warnings\n",
        "import dp_accounting\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "from optax import contrib\n",
        "from optax import losses\n",
        "import optax\n",
        "from jax.example_libraries import stax\n",
        "import tensorflow as tf\n",
        "import tensorflow_datasets as tfds\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "# Report which backend (e.g. CPU or GPU) JAX selected for its first device.\n",
        "platform_name = jax.devices()[0].platform.upper()\n",
        "print(\"JAX running on\", platform_name)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "t7Dn8L_Uw0Yb"
      },
      "source": [
        "This table contains hyperparameters and the corresponding expected test accuracy.\n",
        "\n",
        "\n",
        "| DPSGD  | LEARNING_RATE | NOISE_MULTIPLIER | L2_NORM_CLIP | BATCH_SIZE | NUM_EPOCHS | DELTA | FINAL TEST ACCURACY |\n",
        "| ------ | ------------- | ---------------- | ------------ | ---------- | ---------- | ----- | ------------------- |\n",
        "| False  | 0.1           | NA               | NA           | 256        | 20         | NA    | ~99%                |\n",
        "| True   | 0.25          | 1.3              | 1.5          | 256        | 15         | 1e-5  | ~95%                |\n",
        "| True   | 0.15          | 1.1              | 1.0          | 256        | 60         | 1e-5  | ~96.6%              |\n",
        "| True   | 0.25          | 0.7              | 1.5          | 256        | 45         | 1e-5  | ~97%                |"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jve2h810L3j3"
      },
      "outputs": [],
      "source": [
        "# --- Optimizer ---\n",
        "# Train with DP-SGD when True, plain (non-private) SGD when False.\n",
        "DPSGD = True\n",
        "# Step size shared by both optimizers.\n",
        "LEARNING_RATE = 0.25\n",
        "# Examples per minibatch; also sets the sampling rate used for privacy accounting.\n",
        "BATCH_SIZE = 256\n",
        "# Number of full passes over the training set.\n",
        "NUM_EPOCHS = 15\n",
        "\n",
        "# --- Differential privacy (only used when DPSGD is True) ---\n",
        "# Bound on the L2 norm of each per-example gradient.\n",
        "L2_NORM_CLIP = 1.5\n",
        "# Ratio of the Gaussian noise standard deviation to L2_NORM_CLIP.\n",
        "NOISE_MULTIPLIER = 1.3\n",
        "# Target delta of the (epsilon, delta)-DP guarantee reported by the accountant.\n",
        "DELTA = 1e-5"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iLGeV4y4DBkL"
      },
      "source": [
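        "MNIST consists of 28x28 grayscale images of handwritten digits (10 classes). We'll now load the train and test splits with `tensorflow_datasets`, scale the pixel values to $[0, 1]$, and batch the training set; the whole test split is kept as a single evaluation batch."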
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zynvtk4wDBkL"
      },
      "outputs": [],
      "source": [
        "(train_loader, test_loader), info = tfds.load(\n",
        "    \"mnist\", split=[\"train\", \"test\"], as_supervised=True, with_info=True\n",
        ")\n",
        "\n",
        "# Scales pixel values from [0, 255] to [0, 1]; labels are passed through.\n",
        "scale_pixels = lambda image, label: (tf.cast(image, tf.float32) / 255., label)\n",
        "train_loader = train_loader.map(scale_pixels)\n",
        "test_loader = test_loader.map(scale_pixels)\n",
        "\n",
        "# Reshuffled every epoch; the last partial batch is dropped so every training\n",
        "# step sees exactly BATCH_SIZE examples.\n",
        "train_loader_batched = train_loader.shuffle(\n",
        "    buffer_size=10_000, reshuffle_each_iteration=True\n",
        ").batch(BATCH_SIZE, drop_remainder=True)\n",
        "\n",
        "# Size of the *training* set. The privacy accountant's sampling rate\n",
        "# (BATCH_SIZE / NUM_EXAMPLES) and its delta sanity check refer to the data the\n",
        "# model is trained on, not to the test split.\n",
        "NUM_EXAMPLES = info.splits[\"train\"].num_examples\n",
        "\n",
        "# The whole test split is evaluated as one batch.\n",
        "NUM_TEST_EXAMPLES = info.splits[\"test\"].num_examples\n",
        "test_batch = next(\n",
        "    test_loader.batch(NUM_TEST_EXAMPLES, drop_remainder=True).as_numpy_iterator()\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "o6In7oQ-0EhG"
      },
      "outputs": [],
      "source": [
        "# Two conv -> ReLU -> max-pool stages, then a small MLP head producing 10 logits\n",
        "# (one per class). There is no softmax layer: the loss consumes raw logits.\n",
        "conv_stages = [\n",
        "    stax.Conv(16, (8, 8), padding=\"SAME\", strides=(2, 2)), stax.Relu,\n",
        "    stax.MaxPool((2, 2), (1, 1)),\n",
        "    stax.Conv(32, (4, 4), padding=\"VALID\", strides=(2, 2)), stax.Relu,\n",
        "    stax.MaxPool((2, 2), (1, 1)),\n",
        "]\n",
        "classifier_head = [stax.Flatten, stax.Dense(32), stax.Relu, stax.Dense(10)]\n",
        "init_random_params, predict = stax.serial(*conv_stages, *classifier_head)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "j2OUgc6J0Jsl"
      },
      "source": [
        "This function computes the privacy budget $\\varepsilon$ spent after a given number of training steps, for the target failure probability `DELTA`, using a Renyi-DP (RDP) accountant from `dp_accounting`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "43177TofzuOa"
      },
      "outputs": [],
      "source": [
        "def compute_epsilon(steps):\n",
        "  \"\"\"Returns the epsilon spent after `steps` DP-SGD updates at target `DELTA`.\n",
        "\n",
        "  Each update is modelled as a Poisson-subsampled Gaussian mechanism with\n",
        "  sampling rate BATCH_SIZE / NUM_EXAMPLES and noise multiplier\n",
        "  NOISE_MULTIPLIER; `steps` such updates are composed with an RDP accountant.\n",
        "\n",
        "  Args:\n",
        "    steps: number of noisy gradient updates taken so far.\n",
        "\n",
        "  Returns:\n",
        "    The epsilon of the (epsilon, DELTA)-DP guarantee.\n",
        "  \"\"\"\n",
        "  # A delta above 1 / dataset size makes the guarantee very weak.\n",
        "  if DELTA * NUM_EXAMPLES \u003e 1.:\n",
        "    warnings.warn(\"Your delta might be too high.\")\n",
        "  sampling_rate = BATCH_SIZE / float(NUM_EXAMPLES)\n",
        "  # RDP orders to optimise over: 1.1 to 10.9 in steps of 0.1, then integers 11..63.\n",
        "  rdp_orders = list(jnp.linspace(1.1, 10.9, 99)) + list(range(11, 64))\n",
        "  step_event = dp_accounting.PoissonSampledDpEvent(\n",
        "      sampling_rate, dp_accounting.GaussianDpEvent(NOISE_MULTIPLIER))\n",
        "  accountant = dp_accounting.rdp.RdpAccountant(rdp_orders)\n",
        "  accountant.compose(step_event, steps)\n",
        "  return accountant.get_epsilon(DELTA)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "W9mPtPvB0D3X"
      },
      "outputs": [],
      "source": [
        "@jax.jit\n",
        "def loss_fn(params, batch):\n",
        "  \"\"\"Mean softmax cross-entropy over `batch`; also returns the logits.\n",
        "\n",
        "  The logits are the auxiliary output, so `jax.grad(loss_fn, has_aux=True)`\n",
        "  differentiates only the loss.\n",
        "  \"\"\"\n",
        "  images, labels = batch\n",
        "  logits = predict(params, images)\n",
        "  per_example = losses.softmax_cross_entropy_with_integer_labels(logits, labels)\n",
        "  return per_example.mean(), logits\n",
        "\n",
        "\n",
        "@jax.jit\n",
        "def test_step(params, batch):\n",
        "  \"\"\"Evaluates `params` on `batch`; returns (mean loss, accuracy in percent).\"\"\"\n",
        "  loss, logits = loss_fn(params, batch)\n",
        "  predictions = jnp.argmax(logits, axis=1)\n",
        "  accuracy = jnp.mean(predictions == batch[1])\n",
        "  return loss, 100 * accuracy"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vOet-_860ysL"
      },
      "outputs": [],
      "source": [
        "# DP-SGD clips every per-example gradient to L2_NORM_CLIP and adds Gaussian\n",
        "# noise scaled by NOISE_MULTIPLIER; plain SGD is the non-private baseline.\n",
        "tx = (\n",
        "    contrib.dpsgd(\n",
        "        learning_rate=LEARNING_RATE,\n",
        "        l2_norm_clip=L2_NORM_CLIP,\n",
        "        noise_multiplier=NOISE_MULTIPLIER,\n",
        "        seed=1337,\n",
        "    )\n",
        "    if DPSGD\n",
        "    else optax.sgd(learning_rate=LEARNING_RATE)\n",
        ")\n",
        "\n",
        "# Input shape is (batch, 28, 28, 1); -1 leaves the batch dimension unspecified.\n",
        "init_key = jax.random.PRNGKey(1337)\n",
        "_, params = init_random_params(init_key, (-1, 28, 28, 1))\n",
        "opt_state = tx.init(params)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "b-NmP7g01EdA"
      },
      "outputs": [],
      "source": [
        "@jax.jit\n",
        "def train_step(params, opt_state, batch):\n",
        "  \"\"\"Applies one optimizer update computed on `batch`.\n",
        "\n",
        "  Returns:\n",
        "    A `(new_params, new_opt_state)` pair.\n",
        "  \"\"\"\n",
        "  compute_grads = jax.grad(loss_fn, has_aux=True)\n",
        "  # DPSGD is read when this function is traced, so changing it later requires\n",
        "  # re-running this cell.\n",
        "  if DPSGD:\n",
        "    # contrib.dpsgd needs one gradient per example: give each example its own\n",
        "    # batch axis of size 1, then vmap the gradient over the original batch axis.\n",
        "    batch = jax.tree_util.tree_map(lambda leaf: leaf[:, jnp.newaxis], batch)\n",
        "    compute_grads = jax.vmap(compute_grads, in_axes=(None, 0))\n",
        "\n",
        "  grads = compute_grads(params, batch)[0]  # drop the auxiliary logits\n",
        "  updates, new_opt_state = tx.update(grads, opt_state, params)\n",
        "  return optax.apply_updates(params, updates), new_opt_state"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QMl9dnbJ1OtQ"
      },
      "outputs": [],
      "source": [
        "accuracy, loss, epsilon = [], [], []\n",
        "# Number of optimizer updates actually applied. This is what the privacy\n",
        "# accountant must compose over; because the last partial batch is dropped it is\n",
        "# not exactly epochs * NUM_EXAMPLES / BATCH_SIZE.\n",
        "steps = 0\n",
        "\n",
        "for epoch in range(NUM_EPOCHS):\n",
        "  for batch in train_loader_batched.as_numpy_iterator():\n",
        "    params, opt_state = train_step(params, opt_state, batch)\n",
        "    steps += 1\n",
        "\n",
        "  # Evaluates test accuracy.\n",
        "  test_loss, test_acc = test_step(params, test_batch)\n",
        "  accuracy.append(float(test_acc))\n",
        "  loss.append(float(test_loss))\n",
        "  message = f\"Epoch {epoch + 1}/{NUM_EPOCHS}, test accuracy: {accuracy[-1]:.2f}\"\n",
        "\n",
        "  # Tracks the privacy budget spent so far.\n",
        "  # NOTE(review): the accountant assumes Poisson subsampling, whereas batches here\n",
        "  # come from shuffling + fixed-size batching, so epsilon is an approximation.\n",
        "  if DPSGD:\n",
        "    eps = compute_epsilon(steps)\n",
        "    epsilon.append(eps)\n",
        "    message += f\", epsilon: {eps:.2f}\"\n",
        "  print(message)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9nsV-9_b2qca"
      },
      "outputs": [],
      "source": [
        "# One panel per tracked metric; epsilon only exists when training with DP-SGD.\n",
        "num_panels = 3 if DPSGD else 2\n",
        "_, axs = plt.subplots(ncols=num_panels, figsize=(3 * num_panels, 3))\n",
        "# The training log numbers epochs from 1; use the same numbering on the x axis.\n",
        "epochs = range(1, len(accuracy) + 1)\n",
        "\n",
        "panels = [(accuracy, \"Test accuracy\"), (loss, \"Test loss\")]\n",
        "if DPSGD:\n",
        "  panels.append((epsilon, \"Epsilon\"))\n",
        "\n",
        "for ax, (values, title) in zip(axs, panels):\n",
        "  ax.plot(epochs, values)\n",
        "  ax.set_title(title)\n",
        "  ax.set_xlabel(\"Epoch\")\n",
        "\n",
        "plt.tight_layout()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1ubOEWod3OPj"
      },
      "outputs": [],
      "source": [
        "final_accuracy = accuracy[-1]  # test accuracy after the last epoch\n",
        "print(f'Final accuracy: {final_accuracy}')"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "last_runtime": {
        "build_target": "//learning/grp/tools/ml_python:ml_notebook",
        "kind": "private"
      },
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
